{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import logging\n",
    "import argparse\n",
    "import subprocess\n",
    "from tensorflowonspark import TFCluster\n",
    "import mnist_dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "#Jupyter may already have attached handlers to the root logger, and\n",
    "#logging.basicConfig() does nothing when the root logger has handlers.\n",
    "#Detach them explicitly instead of calling reload(), which is a builtin\n",
    "#only on Python 2, so this cell also runs on Python 3.\n",
    "root_logger = logging.getLogger()\n",
    "for stale_handler in list(root_logger.handlers):\n",
    "    root_logger.removeHandler(stale_handler)\n",
    "    stale_handler.close()\n",
    "logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.INFO, datefmt='%I:%M:%S')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "def _str2bool(value):\n",
    "    \"\"\"Convert a command-line token such as 'true' or '0' into a bool.\n",
    "\n",
    "    Raises argparse.ArgumentTypeError for unrecognised text so argparse\n",
    "    reports a usage error instead of storing an arbitrary string.\n",
    "    \"\"\"\n",
    "    token = value.strip().lower()\n",
    "    if token in (\"true\", \"t\", \"yes\", \"y\", \"1\"):\n",
    "        return True\n",
    "    if token in (\"false\", \"f\", \"no\", \"n\", \"0\"):\n",
    "        return False\n",
    "    raise argparse.ArgumentTypeError(\"expected a boolean value, got %r\" % value)\n",
    "\n",
    "#Options shared by the training and inference runs; each run calls\n",
    "#parser.parse_args() with its own explicit argument list\n",
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument(\"--epochs\", help=\"number of epochs\", type=int, default=1)\n",
    "parser.add_argument(\"--images\", help=\"HDFS path to MNIST images in parallelized format\")\n",
    "parser.add_argument(\"--labels\", help=\"HDFS path to MNIST labels in parallelized format\")\n",
    "parser.add_argument(\"--format\", help=\"example format\", choices=[\"csv\",\"pickle\",\"tfr\"], default=\"csv\")\n",
    "parser.add_argument(\"--model\", help=\"HDFS path to save/load model during train/test\", default=\"mnist_model\")\n",
    "parser.add_argument(\"--readers\", help=\"number of reader/enqueue threads\", type=int, default=1)\n",
    "parser.add_argument(\"--steps\", help=\"maximum number of steps\", type=int, default=500)\n",
    "parser.add_argument(\"--batch_size\", help=\"number of examples per batch\", type=int, default=100)\n",
    "parser.add_argument(\"--mode\", help=\"train|inference\", default=\"train\")\n",
    "#Without a type, '--rdma False' would be stored as the truthy string 'False'.\n",
    "#Accept either a bare '--rdma' flag or an explicit boolean value.\n",
    "parser.add_argument(\"--rdma\", help=\"use rdma connection\", type=_str2bool, nargs=\"?\", const=True, default=False)\n",
    "#Executor count handed to TFCluster.run when the clusters are started\n",
    "num_executors = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "#Start from a clean slate: delete any model directory left by an earlier run\n",
    "stale_model_cmd = [\"rm\", \"-rf\", \"mnist_model\"]\n",
    "subprocess.call(stale_model_cmd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "#verify training images\n",
    "train_images_files = \"csv/train/images\"\n",
    "#universal_newlines=True makes check_output return text on Python 3 as well\n",
    "#as Python 2, so print() shows the listing itself rather than a bytes repr\n",
    "print(subprocess.check_output([\"ls\", \"-l\", train_images_files], universal_newlines=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#verify training labels\n",
    "train_labels_files = \"csv/train/labels\"\n",
    "#universal_newlines=True makes check_output return text on Python 3 as well\n",
    "#as Python 2, so print() shows the listing itself rather than a bytes repr\n",
    "print(subprocess.check_output([\"ls\", \"-l\", train_labels_files], universal_newlines=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "#Build the training argument list, then parse it into `args`\n",
    "train_argv = ['--mode', 'train',\n",
    "              '--steps', '3000',\n",
    "              '--epochs', '5',\n",
    "              '--images', train_images_files,\n",
    "              '--labels', train_labels_files]\n",
    "args = parser.parse_args(train_argv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "#Bring up the TensorFlowOnSpark cluster that will run training\n",
    "cluster = TFCluster.run(\n",
    "    sc,  # not defined in this notebook; expected to be supplied by the PySpark kernel\n",
    "    mnist_dist.map_fun,\n",
    "    args,\n",
    "    num_executors,\n",
    "    1,  # presumably the number of parameter servers - confirm against TFCluster.run\n",
    "    True,  # presumably the TensorBoard switch - confirm against TFCluster.run\n",
    "    TFCluster.InputMode.SPARK\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "def to_int_row(line):\n",
    "    \"\"\"Split one comma-separated line and convert every field to int.\"\"\"\n",
    "    return [int(field) for field in line.split(',')]\n",
    "\n",
    "def to_float_row(line):\n",
    "    \"\"\"Split one comma-separated line and convert every field to float.\"\"\"\n",
    "    return [float(field) for field in line.split(',')]\n",
    "\n",
    "#Read the image and label text files as RDDs, pair them up element by\n",
    "#element and hand the combined RDD to the cluster for training\n",
    "images = sc.textFile(args.images).map(to_int_row)\n",
    "labels = sc.textFile(args.labels).map(to_float_row)\n",
    "dataRDD = images.zip(labels)\n",
    "cluster.train(dataRDD, args.epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#Shut down the cluster that was started for training; `cluster` is\n",
    "#reassigned later when the inference cluster is started\n",
    "cluster.shutdown()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "#List the contents of mnist_model, the parser's default --model path.\n",
    "#universal_newlines=True makes check_output return text on Python 3 as well\n",
    "#as Python 2, so print() shows the listing itself rather than a bytes repr\n",
    "print(subprocess.check_output([\"ls\", \"-l\", \"mnist_model\"], universal_newlines=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "#verify test images\n",
    "test_images_files = \"csv/test/images\"\n",
    "#universal_newlines=True makes check_output return text on Python 3 as well\n",
    "#as Python 2, so print() shows the listing itself rather than a bytes repr\n",
    "print(subprocess.check_output([\"ls\", \"-l\", test_images_files], universal_newlines=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "#verify test labels\n",
    "test_labels_files = \"csv/test/labels\"\n",
    "#universal_newlines=True makes check_output return text on Python 3 as well\n",
    "#as Python 2, so print() shows the listing itself rather than a bytes repr\n",
    "print(subprocess.check_output([\"ls\", \"-l\", test_labels_files], universal_newlines=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "#Build the inference argument list, then parse it into `args`\n",
    "inference_argv = ['--mode', 'inference',\n",
    "                  '--images', test_images_files,\n",
    "                  '--labels', test_labels_files]\n",
    "args = parser.parse_args(inference_argv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "#Bring up a fresh TensorFlowOnSpark cluster for the inference pass\n",
    "cluster = TFCluster.run(\n",
    "    sc,  # not defined in this notebook; expected to be supplied by the PySpark kernel\n",
    "    mnist_dist.map_fun,\n",
    "    args,\n",
    "    num_executors,\n",
    "    1,  # presumably the number of parameter servers - confirm against TFCluster.run\n",
    "    False,  # presumably the TensorBoard switch - confirm against TFCluster.run\n",
    "    TFCluster.InputMode.SPARK\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "def csv_row_parser(cast):\n",
    "    \"\"\"Return a function that splits a CSV line and applies `cast` to every field.\"\"\"\n",
    "    return lambda line: [cast(field) for field in line.split(',')]\n",
    "\n",
    "#Read the image and label text files as RDDs and pair them element by element\n",
    "images = sc.textFile(args.images).map(csv_row_parser(int))\n",
    "labels = sc.textFile(args.labels).map(csv_row_parser(float))\n",
    "dataRDD = images.zip(labels)\n",
    "#Feed the paired RDD to the cluster and display the first 20 results\n",
    "prediction_results = cluster.inference(dataRDD)\n",
    "prediction_results.take(20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "#Shut down the cluster that was started for inference\n",
    "cluster.shutdown()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
